import numpy as np
import os
import seaborn


def get_pixel_vals(x, b, dim, nbins):
    """Accumulate one value component of every sample onto an nbins x nbins grid.

    Each entry of ``x`` is read as ``entry[0][dim]`` (the value that gets
    accumulated) and ``entry[1][0]`` / ``entry[1][1]`` (the two coordinates
    that pick the grid cell), e.g. ``x = [[[1., 2.], [2., 3.]]]``.

    A coordinate is mapped to the first index ``k`` in ``range(nbins)`` for
    which ``coordinate < b[k]``; if no edge satisfies that, the sample goes
    into the last cell (``nbins - 1``).

    Returns:
        A tuple ``(totals / counts, totals, counts)``. ``counts`` starts at
        one in every cell, so the division never hits zero, and every
        accumulated value carries a ``+1.`` offset.
    """
    totals = np.zeros((nbins, nbins))
    counts = np.ones((nbins, nbins))
    overflow_cell = nbins - 1

    def first_edge_above(coord):
        # Index of the first edge strictly greater than coord, else the last cell.
        return next((k for k in range(nbins) if coord < b[k]), overflow_cell)

    for entry in x:
        row = first_edge_above(entry[1][0])
        col = first_edge_above(entry[1][1])
        counts[row][col] += 1
        totals[row][col] += entry[0][dim] + 1.
    return totals / counts, totals, counts

def bins(numbins, l, h):
    """Return the upper edges of ``numbins`` equal-width bins spanning ``l``..``h``.

    The edges are built by repeatedly adding the bin width to a running
    value that starts at ``l``, so the list holds ``numbins`` entries and the
    lower bound ``l`` itself is not included.
    """
    width = (h - l) / numbins
    edges = []
    upper = l
    for _ in range(numbins):
        # Cumulative addition (rather than l + k * width) is kept on purpose
        # so the produced floats match the established values exactly.
        upper += width
        edges.append(upper)
    return edges

# def main():
#     nbins=256
#     b = bins(nbins, l=-2.5, h=2.5)
#     for logdir in os.listdir('./logs/action_logs/'):
#         dirname = os.path.join("./logs/action_logs/",logdir)
#         files = os.listdir(dirname)
#         delx_vals = []
#         dely_vals = []
#         for file in files:
#             if file.endswith('.npy'):
#                 print(file)
#                 logfile = np.load(os.path.join(dirname,file))
#                 dx,arr,arrctr = get_pixel_vals(logfile, b, 0, nbins)
#                 delx_vals.append(dx.copy())
#                 dy,arr,arrctr = get_pixel_vals(logfile, b, 1, nbins)
#                 dely_vals.append(dy.copy())
        
#         delx_vals = np.array(delx_vals).mean(axis=0)
#         dely_vals = np.array(dely_vals).mean(axis=0)
#         x=seaborn.heatmap(delx_vals,vmin=0., vmax=2.)
#         y=seaborn.heatmap(dely_vals,vmin=0., vmax=2.)

#         x.get_figure().savefig('./x_{}.png'.format(logdir))
#         y.get_figure().savefig('./y_{}.png'.format(logdir))
        
#         np.save("./logs/delx_{}.npy".format(logdir),delx_vals)
#         np.save("./logs/dely_{}.npy".format(logdir),dely_vals)

def make_plots(logdir):
    """Save mean x/y heatmaps for all ``.npy`` logs found in ``logdir``.

    Every ``.npy`` file in ``logdir`` is loaded and binned twice with
    ``get_pixel_vals`` (component 0 and component 1) on a 256 x 256 grid
    whose edges come from ``bins(256, l=-2.5, h=2.5)``.  The per-file grids
    are averaged and written to ``./plots/x_<name>.png`` and
    ``./plots/y_<name>.png``, where ``<name>`` is the last path component of
    ``logdir``.

    Args:
        logdir: Path of the directory holding the ``.npy`` log files.

    Raises:
        ValueError: If ``logdir`` contains no ``.npy`` file.
    """
    os.makedirs('./plots/', exist_ok=True)
    nbins = 256
    b = bins(nbins, l=-2.5, h=2.5)
    delx_vals = []
    dely_vals = []
    for file in os.listdir(logdir):
        if not file.endswith('.npy'):
            continue
        print(file)
        logfile = np.load(os.path.join(logdir, file))
        dx, _, _ = get_pixel_vals(logfile, b, 0, nbins)
        delx_vals.append(dx)
        dy, _, _ = get_pixel_vals(logfile, b, 1, nbins)
        dely_vals.append(dy)

    if not delx_vals:
        # Averaging an empty list yields NaN and the heatmap call then fails
        # with an unrelated error; report the real cause instead.
        raise ValueError("no .npy log files found in {!r}".format(logdir))

    delx_vals = np.array(delx_vals).mean(axis=0)
    dely_vals = np.array(dely_vals).mean(axis=0)

    # normpath drops a trailing separator so "runs/foo/" still yields "foo".
    tag = os.path.basename(os.path.normpath(logdir))

    # seaborn.heatmap draws on the currently active axes when no ``ax`` is
    # given.  Drawing both maps before saving therefore stacked the y map on
    # top of the x map and wrote the same combined figure twice.  Save each
    # map right after drawing it and clear the figure before the next one.
    for prefix, grid in (('x', delx_vals), ('y', dely_vals)):
        axes = seaborn.heatmap(grid, vmin=0., vmax=2.)
        figure = axes.get_figure()
        # Explicit extension: a dot inside ``tag`` would otherwise be taken
        # as the image format by savefig.
        figure.savefig('./plots/{}_{}.png'.format(prefix, tag))
        figure.clf()

    # np.save("./logs/delx_{}.npy".format(logdir),delx_vals)
    # np.save("./logs/dely_{}.npy".format(logdir),dely_vals)


if __name__ == "__main__":
    # ``main`` only exists as commented-out code in this file, so calling it
    # raised NameError.  Drive make_plots() directly instead.
    import sys

    # Log directories may be passed on the command line; with no arguments
    # fall back to every sub-directory of ./logs/action_logs/, the layout the
    # commented-out main() used to walk.
    log_dirs = sys.argv[1:]
    if not log_dirs:
        log_root = "./logs/action_logs/"
        log_dirs = [
            os.path.join(log_root, entry)
            for entry in sorted(os.listdir(log_root))
            if os.path.isdir(os.path.join(log_root, entry))
        ]
    for log_dir in log_dirs:
        make_plots(log_dir)
